{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-29T05:17:23.183479Z",
     "start_time": "2025-06-29T05:17:20.503755Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Gated recurrent unit (GRU) implemented from scratch\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "# Data-loading hyperparameters handed to the d2l time machine loader\n",
    "batch_size = 32\n",
    "num_steps = 35\n",
    "# The loader returns the training iterator together with its vocabulary\n",
    "train_iter, vocab = d2l.load_data_time_machine(\n",
    "    batch_size, num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8b7813f5c15c9a14",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-29T05:17:43.807328Z",
     "start_time": "2025-06-29T05:17:43.795369Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_params(vocab_size, num_hiddens, device):\n",
    "    \"\"\"Create the trainable tensors of a GRU followed by a linear output layer.\n",
    "\n",
    "    Weights are standard-normal samples scaled by 0.01 and biases are zeros.\n",
    "    The input and output widths both equal ``vocab_size``.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: Width of the input and of the output layer.\n",
    "        num_hiddens: Width of the hidden state.\n",
    "        device: Device on which every tensor is created.\n",
    "\n",
    "    Returns:\n",
    "        A list of 11 tensors with gradient tracking enabled, ordered as\n",
    "        [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q].\n",
    "    \"\"\"\n",
    "    def scaled_noise(*dims):\n",
    "        # Standard-normal samples shrunk by a factor of 0.01\n",
    "        return torch.randn(size=dims, device=device) * 0.01\n",
    "\n",
    "    def zero_bias(width):\n",
    "        return torch.zeros(width, device=device)\n",
    "\n",
    "    params = []\n",
    "    # One (input weight, hidden weight, bias) triple each for the update gate,\n",
    "    # the reset gate and the candidate hidden state, created in that order.\n",
    "    for _ in range(3):\n",
    "        params.append(scaled_noise(vocab_size, num_hiddens))\n",
    "        params.append(scaled_noise(num_hiddens, num_hiddens))\n",
    "        params.append(zero_bias(num_hiddens))\n",
    "    # Output layer parameters\n",
    "    params.append(scaled_noise(num_hiddens, vocab_size))\n",
    "    params.append(zero_bias(vocab_size))\n",
    "    # Attach gradients to every parameter\n",
    "    for tensor in params:\n",
    "        tensor.requires_grad_(True)\n",
    "    return params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "62fdb56f71185579",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-29T05:18:08.679140Z",
     "start_time": "2025-06-29T05:18:08.664636Z"
    }
   },
   "outputs": [],
   "source": [
    "def init_gru_state(batch_size, num_hiddens, device):\n",
    "    \"\"\"Build the initial GRU state as a 1-tuple holding an all-zero tensor.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Size of the first dimension of the state tensor.\n",
    "        num_hiddens: Size of the second dimension of the state tensor.\n",
    "        device: Device on which the tensor is created.\n",
    "\n",
    "    Returns:\n",
    "        A tuple with one zero tensor of shape (batch_size, num_hiddens).\n",
    "    \"\"\"\n",
    "    initial_hidden = torch.zeros((batch_size, num_hiddens), device=device)\n",
    "    return (initial_hidden,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9baee0ea40fd211",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-29T05:19:17.379186Z",
     "start_time": "2025-06-29T05:19:17.364448Z"
    }
   },
   "outputs": [],
   "source": [
    "def gru(inputs, state, params):\n",
    "    \"\"\"Run a GRU over a sequence of inputs and project each hidden state.\n",
    "\n",
    "    Args:\n",
    "        inputs: Iterable yielding one input tensor per step; each one is\n",
    "            matrix-multiplied with the input-to-hidden weights.\n",
    "        state: 1-tuple containing the hidden state to start from.\n",
    "        params: The 11 parameters in the order\n",
    "            (W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q).\n",
    "\n",
    "    Returns:\n",
    "        A pair ``(outputs, (H,))``: the per-step outputs concatenated along\n",
    "        dimension 0, and a 1-tuple with the hidden state after the last step.\n",
    "    \"\"\"\n",
    "    (W_xz, W_hz, b_z,\n",
    "     W_xr, W_hr, b_r,\n",
    "     W_xh, W_hh, b_h,\n",
    "     W_hq, b_q) = params\n",
    "    (hidden,) = state\n",
    "\n",
    "    def pre_activation(x, h, W_x, W_h, b):\n",
    "        # Shared affine form used by both gates and the candidate state\n",
    "        return (x @ W_x) + (h @ W_h) + b\n",
    "\n",
    "    step_outputs = []\n",
    "    for step_input in inputs:\n",
    "        update_gate = torch.sigmoid(\n",
    "            pre_activation(step_input, hidden, W_xz, W_hz, b_z))\n",
    "        reset_gate = torch.sigmoid(\n",
    "            pre_activation(step_input, hidden, W_xr, W_hr, b_r))\n",
    "        # The reset gate scales the previous state before it feeds the candidate\n",
    "        candidate = torch.tanh(\n",
    "            pre_activation(step_input, reset_gate * hidden, W_xh, W_hh, b_h))\n",
    "        # The update gate blends the previous state with the candidate state\n",
    "        hidden = update_gate * hidden + (1 - update_gate) * candidate\n",
    "        step_outputs.append(hidden @ W_hq + b_q)\n",
    "    return torch.cat(step_outputs, dim=0), (hidden,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "22d90c17833883b6",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 5\u001b[0m\n\u001b[0;32m      2\u001b[0m num_epochs, lr \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m500\u001b[39m, \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m      3\u001b[0m model \u001b[38;5;241m=\u001b[39m d2l\u001b[38;5;241m.\u001b[39mRNNModelScratch(\u001b[38;5;28mlen\u001b[39m(vocab), num_hiddens, device, get_params,\n\u001b[0;32m      4\u001b[0m                             init_gru_state, gru)\n\u001b[1;32m----> 5\u001b[0m \u001b[43md2l\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_ch8\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvocab\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\d2l\\torch.py:803\u001b[0m, in \u001b[0;36mtrain_ch8\u001b[1;34m(net, train_iter, vocab, lr, num_epochs, device, use_random_iter)\u001b[0m\n\u001b[0;32m    801\u001b[0m \u001b[38;5;66;03m# Train and predict\u001b[39;00m\n\u001b[0;32m    802\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[1;32m--> 803\u001b[0m     ppl, speed \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_epoch_ch8\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnet\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mupdater\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    804\u001b[0m \u001b[43m                                 \u001b[49m\u001b[43muse_random_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    805\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m (epoch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m10\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m    806\u001b[0m         \u001b[38;5;28mprint\u001b[39m(predict(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtime traveller\u001b[39m\u001b[38;5;124m'\u001b[39m))\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\d2l\\torch.py:780\u001b[0m, in \u001b[0;36mtrain_epoch_ch8\u001b[1;34m(net, train_iter, loss, updater, device, use_random_iter)\u001b[0m\n\u001b[0;32m    778\u001b[0m     updater\u001b[38;5;241m.\u001b[39mstep()\n\u001b[0;32m    779\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 780\u001b[0m     \u001b[43ml\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    781\u001b[0m     grad_clipping(net, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m    782\u001b[0m     \u001b[38;5;66;03m# Since the `mean` function has been invoked\u001b[39;00m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\torch\\_tensor.py:648\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    638\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    639\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m    640\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m    641\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    646\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m    647\u001b[0m     )\n\u001b[1;32m--> 648\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    649\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m    650\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\torch\\autograd\\__init__.py:353\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    348\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m    350\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[0;32m    351\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m    352\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 353\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    354\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    355\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    356\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    357\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    358\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    359\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    360\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    361\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\torch\\autograd\\graph.py:824\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[1;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[0;32m    822\u001b[0m     unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[0;32m    823\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 824\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m Variable\u001b[38;5;241m.\u001b[39m_execution_engine\u001b[38;5;241m.\u001b[39mrun_backward(  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m    825\u001b[0m         t_outputs, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[0;32m    826\u001b[0m     )  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m    827\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m    828\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Model dimensions and target device\n",
    "vocab_size = len(vocab)\n",
    "num_hiddens = 256\n",
    "device = d2l.try_gpu()\n",
    "# Training settings\n",
    "num_epochs = 500\n",
    "lr = 1\n",
    "# Plug the from-scratch GRU pieces into the d2l RNN wrapper, then train it\n",
    "model = d2l.RNNModelScratch(\n",
    "    vocab_size, num_hiddens, device, get_params, init_gru_state, gru)\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f15cd2941e8643ca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
